// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using NuGet.Packaging.Core;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Model;
using NuGet.Versioning;
using NuGet.VisualStudio.Internal.Contracts;

namespace NuGet.PackageManagement.VisualStudio
{
    /// <summary>
    /// Looks up the known vulnerabilities that apply to a specific package version by querying
    /// every source repository supplied at construction time.
    /// </summary>
    public class PackageVulnerabilityService : IPackageVulnerabilityService
    {
        private readonly IEnumerable<SourceRepository> _sourceRepositories;
        private readonly INuGetUILogger _logger;

        /// <summary>
        /// Creates the service.
        /// </summary>
        /// <param name="sourceRepositories">Repositories queried for vulnerability data.</param>
        /// <param name="logger">Logger that receives a warning for each failure reported while fetching vulnerability data.</param>
        /// <exception cref="ArgumentNullException">Either argument is <c>null</c>.</exception>
        public PackageVulnerabilityService(IEnumerable<SourceRepository> sourceRepositories, INuGetUILogger logger)
        {
            _sourceRepositories = sourceRepositories ?? throw new ArgumentNullException(nameof(sourceRepositories));
            _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        }

        /// <summary>
        /// Gets the vulnerabilities whose version range is satisfied by <paramref name="packageId"/>.
        /// </summary>
        /// <param name="packageId">Id and version of the package to check.</param>
        /// <param name="cancellationToken">Token passed to every repository query.</param>
        /// <returns>One entry per matching vulnerability; an empty list when there are none.</returns>
        public async Task<List<PackageVulnerabilityMetadataContextInfo>> GetVulnerabilityInfoAsync(PackageIdentity packageId, CancellationToken cancellationToken)
        {
            // Keep the fetched data in a local rather than in instance state: overlapping calls must not
            // observe (or null out) each other's result between the null check and the dereference below.
            GetVulnerabilityInfoResult vulnerabilities = await GetAllVulnerabilityDataAsync(cancellationToken);

            if (vulnerabilities?.Exceptions is not null)
            {
                ReplayErrors(vulnerabilities.Exceptions);
            }

            IEnumerable<PackageVulnerabilityInfo> packageVulnerabilities = Enumerable.Empty<PackageVulnerabilityInfo>();

            IReadOnlyList<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> allVulnerabilities = vulnerabilities?.KnownVulnerabilities;
            if (allVulnerabilities is not null && allVulnerabilities.Count > 0)
            {
                packageVulnerabilities = GetKnownVulnerabilities(packageId.Id, packageId.Version, allVulnerabilities);
            }

            return ConvertToPackageVulnerabilityMetadataContextInfo(packageVulnerabilities);
        }

        // Copied and adapted from NuGet.Commands.Restore.Utility.AuditUtility.GetAllVulnerabilityDataAsync
        /// <summary>
        /// Queries all source repositories concurrently and merges their results.
        /// </summary>
        /// <returns>
        /// The combined known vulnerabilities and errors, or <c>null</c> when no repository returned either.
        /// </returns>
        private async Task<GetVulnerabilityInfoResult> GetAllVulnerabilityDataAsync(CancellationToken cancellationToken)
        {
            // Materialize the tasks up front. Select is lazily evaluated, so enumerating the query a
            // second time would invoke GetVulnerabilityInfoAsync again for every source.
            Task<GetVulnerabilityInfoResult>[] pendingResults = _sourceRepositories
                .Select(sr => sr.GetVulnerabilityInfoAsync(cancellationToken))
                .ToArray();

            GetVulnerabilityInfoResult[] results = await Task.WhenAll(pendingResults);

            cancellationToken.ThrowIfCancellationRequested();

            List<Exception> errors = null;
            List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> knownVulnerabilities = null;
            foreach (GetVulnerabilityInfoResult result in results)
            {
                if (result is null)
                {
                    continue;
                }

                if (result.KnownVulnerabilities != null)
                {
                    knownVulnerabilities ??= new();
                    knownVulnerabilities.AddRange(result.KnownVulnerabilities);
                }

                if (result.Exceptions != null)
                {
                    errors ??= new();
                    errors.AddRange(result.Exceptions.InnerExceptions);
                }
            }

            if (knownVulnerabilities == null && errors == null)
            {
                return null;
            }

            return new GetVulnerabilityInfoResult(knownVulnerabilities, errors != null ? new AggregateException(errors) : null);
        }

        // Copied from NuGet.Commands.Restore.Utility.AuditUtility.GetKnownVulnerabilities
        /// <summary>
        /// Collects the distinct vulnerabilities registered under <paramref name="name"/> whose version
        /// range is satisfied by <paramref name="version"/>.
        /// </summary>
        /// <returns>The matching vulnerabilities, or <c>null</c> when there are none.</returns>
        private static List<PackageVulnerabilityInfo> GetKnownVulnerabilities(
            string name,
            NuGetVersion version,
            IReadOnlyList<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> knownVulnerabilities)
        {
            if (knownVulnerabilities == null)
            {
                return null;
            }

            // A set, so that an entry added more than once (e.g. listed in several files) is kept only once.
            HashSet<PackageVulnerabilityInfo> vulnerabilities = null;

            foreach (var file in knownVulnerabilities)
            {
                if (!file.TryGetValue(name, out var packageVulnerabilities))
                {
                    continue;
                }

                foreach (var vulnInfo in packageVulnerabilities)
                {
                    if (vulnInfo.Versions.Satisfies(version))
                    {
                        vulnerabilities ??= new();
                        vulnerabilities.Add(vulnInfo);
                    }
                }
            }

            return vulnerabilities?.ToList();
        }

        /// <summary>
        /// Projects each vulnerability to a <see cref="PackageVulnerabilityMetadataContextInfo"/> built from
        /// its URL and its severity cast to <see cref="int"/>.
        /// </summary>
        /// <returns>The converted entries; an empty list when the input is <c>null</c> or empty.</returns>
        private static List<PackageVulnerabilityMetadataContextInfo> ConvertToPackageVulnerabilityMetadataContextInfo(IEnumerable<PackageVulnerabilityInfo> packageVulnerabilities)
        {
            if (packageVulnerabilities == null)
            {
                return new List<PackageVulnerabilityMetadataContextInfo>(0);
            }

            // An empty sequence naturally yields an empty list, so the input is enumerated only once.
            return packageVulnerabilities
                .Select(pvi => new PackageVulnerabilityMetadataContextInfo(pvi.Url, (int)pvi.Severity))
                .ToList();
        }

        /// <summary>
        /// Logs every inner exception of <paramref name="exceptions"/> as a warning.
        /// </summary>
        private void ReplayErrors(AggregateException exceptions)
        {
            foreach (Exception exception in exceptions.InnerExceptions)
            {
                _logger.Log(ProjectManagement.MessageLevel.Warning, Strings.Error_VulnerabilityDataFetch, exception.Message);
            }
        }
    }
}
